/* Must precede every CRT header (including those pulled in by our own
   headers) to have any effect on the deprecation warnings. */
#ifndef _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_WARNINGS
#endif

#include "testmemload.h"

#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#include <stddef.h>
#include <stdarg.h>
#include <intrin.h>
#include <tchar.h>
#include <assert.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <malloc.h>

#include "memload/memload.h"

#if defined (__cplusplus) || defined (c_plusplus)
extern "C" {
#endif

/*
 * Name of the test DLL that is loaded both from disk and from memory.
 * Debug builds use the "d"-suffixed library; 32-bit and 64-bit Windows
 * share the same name, so only the platform itself is checked here.
 */
#if !(defined(WIN64) || defined(_WIN64) || defined(__WIN64__) || \
      defined(WIN32) || defined(_WIN32) || defined(__WIN32__))
    #error "Platform not supported"
#endif

#if defined(_DEBUG)
    #define CSDLL_LIBRARY_NAME          _T("csdlld")
    #define CSDLL_LIBRARY_FULLNAME      _T("csdlld.dll")
#else
    #define CSDLL_LIBRARY_NAME          _T("csdll")
    #define CSDLL_LIBRARY_FULLNAME      _T("csdll.dll")
#endif

/* Exported symbol looked up in the DLL. This is a narrow (char) string
   because GetProcAddress takes an LPCSTR in every build. */
#define CSDLL_LIBRARY_ADDSUM    "addsum"

/* Signature of the exported "addsum" function. */
typedef int (*fnAddsumProc)(int, int);

void PrintCWD(void)
{
    TCHAR path[1024];
    GetModuleFileName(NULL, path, sizeof(path)/sizeof(path[0]));
    _tprintf(_T("GetModuleFileName: %s\n"), path);
    GetCurrentDirectory(sizeof(path)/sizeof(path[0]), path);
    _tprintf(_T("GetCurrentDirectory: %s\n"), path);
}

int LoadFromFile(void)
{
    fnAddsumProc fnAddsum;
    HRSRC resourceInfo;
    DWORD resourceSize;
    LPVOID resourceData;
    TCHAR buffer[100];

    PrintCWD();
    _tprintf(_T("%s\n"), _T("LoadFromFile ......"));
    HINSTANCE handle = LoadLibrary(CSDLL_LIBRARY_FULLNAME);
    if (handle == NULL) {
        OutputLastError(_T("load library failed"));
        _tprintf(_T("load library failed: %s\n"), CSDLL_LIBRARY_FULLNAME);
        return 1;
    }

    fnAddsum = (fnAddsumProc)GetProcAddress(handle, CSDLL_LIBRARY_ADDSUM);
    if (fnAddsum == NULL) {
        OutputLastError(_T("get proc address failed"));
        _tprintf(_T("get proc address failed.\n"));
        return 2;
    }
    _tprintf(_T("get proc address: %s - 0x%p\n"), CSDLL_LIBRARY_ADDSUM, fnAddsum);
    _tprintf(_T("From file proc result: %d\n"), fnAddsum(1, 2));

    resourceInfo = FindResource(handle, MAKEINTRESOURCE(VS_VERSION_INFO), RT_VERSION);
    _tprintf(_T("FindResource returned 0x%p\n"), resourceInfo);

    resourceSize = SizeofResource(handle, resourceInfo);
    resourceData = LoadResource(handle, resourceInfo);
    _tprintf(_T("Resource data: %ld bytes at 0x%p\n"), resourceSize, resourceData);

    LoadString(handle, 1, buffer, sizeof(buffer));
    _tprintf(_T("String1: %s\n"), buffer);

    LoadString(handle, 20, buffer, sizeof(buffer));
    _tprintf(_T("String2: %s\n"), buffer);

    FreeLibrary(handle);
    return 0;
}

/*
 * Read the whole test DLL file into a freshly malloc'ed buffer.
 *
 * On success, returns the buffer (the caller owns it and must free() it)
 * and stores its byte length in *pSize. On failure, returns NULL and sets
 * *pSize to 0. Failures include: the file cannot be opened, is empty,
 * cannot be sized or read completely, or memory is exhausted. The file is
 * closed on every path.
 */
void* ReadLibrary(size_t* pSize)
{
    FILE* fp;
    long length;
    size_t read;
    void* result = NULL;

    *pSize = 0;
    fp = _tfopen(CSDLL_LIBRARY_FULLNAME, _T("rb"));
    if (fp == NULL)
    {
        _tprintf(_T("Can't open DLL file \"%s\".\n"), CSDLL_LIBRARY_FULLNAME);
        return NULL;
    }

    if (fseek(fp, 0, SEEK_END) != 0)
    {
        goto done;
    }
    /* ftell() reports -1 on failure; never pass that on to malloc(). */
    length = ftell(fp);
    if (length <= 0)
    {
        goto done;
    }

    result = malloc((size_t)length);
    if (result == NULL)
    {
        goto done;
    }

    if (fseek(fp, 0, SEEK_SET) != 0)
    {
        free(result);
        result = NULL;
        goto done;
    }

    read = fread(result, 1, (size_t)length, fp);
    if (read != (size_t)length)
    {
        free(result);
        result = NULL;
        goto done;
    }
    *pSize = read;

done:
    fclose(fp);
    return result;
}

/*
 * Load the test DLL from an in-memory copy with MemoryLoadLibrary and run
 * the same checks as LoadFromFile: call "addsum", look up the version
 * resource and read string resources 1 and 20.
 *
 * Returns 0 on success, 1 if the DLL file cannot be read, 2 if the image
 * cannot be loaded from memory, 3 if "addsum" is not exported. The memory
 * module and the file buffer are released on every path.
 */
int LoadFromMemory(void)
{
    void *data = NULL;
    size_t size = 0;
    HMEMORYMODULE handle = NULL;
    fnAddsumProc fnAddsum;
    HMEMORYRSRC resourceInfo;
    DWORD resourceSize = 0;
    LPVOID resourceData = NULL;
    TCHAR buffer[100];
    const int bufferCount = (int)(sizeof(buffer) / sizeof(buffer[0]));
    int iret = 0;

    PrintCWD();
    _tprintf(_T("%s\n"), _T("LoadFromMemory ......"));
    data = ReadLibrary(&size);
    if (data == NULL)
    {
        iret = 1;
        goto exit;
    }

    handle = MemoryLoadLibrary(data, size);
    if (handle == NULL)
    {
        OutputLastError(_T("Can't load library from memory"));
        _tprintf(_T("Can't load library from memory.\n"));
        iret = 2;
        goto exit;
    }

    fnAddsum = (fnAddsumProc)MemoryGetProcAddress(handle, CSDLL_LIBRARY_ADDSUM);
    if (fnAddsum == NULL)
    {
        OutputLastError(_T("Can't get proc address"));
        _tprintf(_T("Can't get proc address.\n"));
        iret = 3;
        goto exit;
    }
    /* CSDLL_LIBRARY_ADDSUM is a narrow string even in UNICODE builds; the MSVC
       CRT reads %hs as char* for both printf and wprintf. */
    _tprintf(_T("get proc address: %hs - 0x%p\n"), CSDLL_LIBRARY_ADDSUM, (void*)fnAddsum);
    _tprintf(_T("From memory proc result: %d\n"), fnAddsum(1, 2));

    resourceInfo = MemoryFindResource(handle, MAKEINTRESOURCE(VS_VERSION_INFO), RT_VERSION);
    _tprintf(_T("MemoryFindResource returned 0x%p\n"), (void*)resourceInfo);

    if (resourceInfo != NULL)
    {
        resourceSize = MemorySizeofResource(handle, resourceInfo);
        resourceData = MemoryLoadResource(handle, resourceInfo);
    }
    _tprintf(_T("Memory resource data: %lu bytes at 0x%p\n"), resourceSize, resourceData);

    /* Pass a character count (not sizeof) and reserve the last slot: the
       string may fill the buffer without a terminator, so add one ourselves.
       NOTE(review): assumes MemoryLoadString's size argument counts TCHARs,
       as in upstream MemoryModule -- confirm against memload.h. */
    buffer[0] = _T('\0');
    MemoryLoadString(handle, 1, buffer, bufferCount - 1);
    buffer[bufferCount - 1] = _T('\0');
    _tprintf(_T("String1: %s\n"), buffer);

    buffer[0] = _T('\0');
    MemoryLoadString(handle, 20, buffer, bufferCount - 1);
    buffer[bufferCount - 1] = _T('\0');
    _tprintf(_T("String2: %s\n"), buffer);

exit:
    if (handle != NULL)
    {
        MemoryFreeLibrary(handle);
        handle = NULL;
    }
    free(data);
    data = NULL;
    return iret;
}

/* Capacity of each scripted call table in CallList. */
#define MAX_CALLS   (20)

/*
 * Scripted allocator/free behaviour for the MemoryLoadLibraryEx tests.
 * Each mock call consumes the next entry of the matching table and advances
 * the matching cursor; a NULL entry marks the end of the script.
 */
typedef struct CallList {
    int current_alloc_call;                 /* next index into alloc_calls */
    int current_free_call;                  /* next index into free_calls */
    CustomAllocFunc alloc_calls[MAX_CALLS]; /* allocators, in call order */
    CustomFreeFunc free_calls[MAX_CALLS];   /* free functions, in call order */
} CallList;


/*
 * Allocator stub that always fails; used to simulate out-of-memory inside
 * MemoryLoadLibraryEx. Every argument is ignored.
 */
LPVOID MemoryFailingAlloc(LPVOID address, SIZE_T size, DWORD allocationType, DWORD protect, void* userdata)
{
    (void)address;
    (void)size;
    (void)allocationType;
    (void)protect;
    (void)userdata;
    return NULL;
}

/*
 * CustomAllocFunc that replays the allocator script stored in a CallList.
 *
 * userdata must point to a CallList. Each call forwards to the next entry
 * of alloc_calls and advances current_alloc_call. The scripted function
 * receives NULL as its own userdata. Running past the script is a test
 * error: it asserts in debug builds and fails the allocation (returns NULL)
 * otherwise. It never indexes past the table or calls through NULL.
 */
LPVOID MemoryMockAlloc(LPVOID address, SIZE_T size, DWORD allocationType, DWORD protect, void* userdata)
{
    CallList* calls = (CallList*)userdata;
    CustomAllocFunc current_func;

    assert(calls->current_alloc_call < MAX_CALLS);
    if (calls->current_alloc_call >= MAX_CALLS) {
        return NULL;
    }
    current_func = calls->alloc_calls[calls->current_alloc_call++];
    assert(current_func != NULL);
    if (current_func == NULL) {
        return NULL;
    }
    return current_func(address, size, allocationType, protect, NULL);
}

/*
 * CustomFreeFunc that replays the free script stored in a CallList.
 *
 * userdata must point to a CallList. Each call forwards to the next entry
 * of free_calls and advances current_free_call. The scripted function
 * receives NULL as its own userdata. Running past the script is a test
 * error: it asserts in debug builds and reports failure (FALSE) otherwise.
 * It never indexes past the table or calls through NULL.
 */
BOOL MemoryMockFree(LPVOID lpAddress, SIZE_T dwSize, DWORD dwFreeType, void* userdata)
{
    CallList* calls = (CallList*)userdata;
    CustomFreeFunc current_func;

    assert(calls->current_free_call < MAX_CALLS);
    if (calls->current_free_call >= MAX_CALLS) {
        return FALSE;
    }
    current_func = calls->free_calls[calls->current_free_call++];
    assert(current_func != NULL);
    if (current_func == NULL) {
        return FALSE;
    }
    return current_func(lpAddress, dwSize, dwFreeType, NULL);
}

/*
 * Copy a NULL-terminated variadic list of function pointers into funcs,
 * which must have room for MAX_CALLS entries.
 *
 * The terminating NULL is stored as well, and every slot after it is
 * cleared. A mock that runs past the script therefore hits its NULL check
 * instead of an uninitialized pointer. Supplying MAX_CALLS or more
 * non-NULL functions is a test error: it asserts, and never writes past
 * the array even when asserts are disabled.
 */
void InitFuncs(void** funcs, va_list args) {
    int i;

    for (i = 0; i < MAX_CALLS; i++) {
        funcs[i] = va_arg(args, void*);
        if (funcs[i] == NULL) break;
    }
    /* i == MAX_CALLS means the NULL terminator did not fit in the table. */
    assert(i < MAX_CALLS);

    for (i = i + 1; i < MAX_CALLS; i++) {
        funcs[i] = NULL;
    }
}

/*
 * Script the allocator sequence consumed by MemoryMockAlloc. The variadic
 * arguments are CustomAllocFunc values terminated by NULL. The allocation
 * cursor is rewound to the first entry.
 */
void InitAllocFuncs(CallList* calls, ...) {
    va_list funcList;

    calls->current_alloc_call = 0;

    va_start(funcList, calls);
    InitFuncs((void**)calls->alloc_calls, funcList);
    va_end(funcList);
}

/*
 * Script the free sequence consumed by MemoryMockFree. The variadic
 * arguments are CustomFreeFunc values terminated by NULL. The free cursor
 * is rewound to the first entry.
 */
void InitFreeFuncs(CallList* calls, ...) {
    va_list funcList;

    calls->current_free_call = 0;

    va_start(funcList, calls);
    InitFuncs((void**)calls->free_calls, funcList);
    va_end(funcList);
}

/*
 * Script every one of the MAX_CALLS free slots with the same freeFunc. Used
 * when the number of internal free calls is not fixed. The free cursor is
 * rewound to the first entry.
 */
void InitFreeFunc(CallList* calls, CustomFreeFunc freeFunc) {
    CustomFreeFunc* slot = calls->free_calls;
    CustomFreeFunc* const end = calls->free_calls + MAX_CALLS;

    calls->current_free_call = 0;
    while (slot != end) {
        *slot++ = freeFunc;
    }
}

/*
 * When every allocation fails, MemoryLoadLibraryEx must return NULL with
 * ERROR_OUTOFMEMORY and must not call the free callback. Freeing the NULL
 * handle afterwards must not call it either. The script allows two
 * failing allocation attempts.
 */
void TestFailingAllocation(void *data, size_t size) {
    CallList expected_calls;
    HMEMORYMODULE handle;

    /* The sentinel must be a real pointer: in C++ NULL may be the int 0,
       and InitFuncs reads each argument with va_arg(args, void*). */
    InitAllocFuncs(&expected_calls, MemoryFailingAlloc, MemoryFailingAlloc, (void*)NULL);
    InitFreeFuncs(&expected_calls, (void*)NULL);

    handle = MemoryLoadLibraryEx(
        data, size, MemoryMockAlloc, MemoryMockFree, MemoryDefaultLoadLibrary,
        MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, &expected_calls);

    assert(handle == NULL);
    assert(GetLastError() == ERROR_OUTOFMEMORY);
    assert(expected_calls.current_free_call == 0);

    /* Freeing a NULL module must be a no-op. */
    MemoryFreeLibrary(handle);
    assert(expected_calls.current_free_call == 0);
}

/*
 * Let the first four allocations succeed and fail the fifth, so that
 * MemoryLoadLibraryEx has to clean up a partially loaded image. The script
 * allows at most one free call during loading. Freeing the returned handle
 * (presumably NULL after the failure) must not call the free callback again.
 */
void TestCleanupAfterFailingAllocation(void *data, size_t size) {
    CallList expected_calls;
    HMEMORYMODULE handle;
    int free_calls_after_loading;

    /* The sentinel must be a real pointer: in C++ NULL may be the int 0,
       and InitFuncs reads each argument with va_arg(args, void*). */
    InitAllocFuncs(&expected_calls,
        MemoryDefaultAlloc,
        MemoryDefaultAlloc,
        MemoryDefaultAlloc,
        MemoryDefaultAlloc,
        MemoryFailingAlloc,
        (void*)NULL);
    InitFreeFuncs(&expected_calls, MemoryDefaultFree, (void*)NULL);

    handle = MemoryLoadLibraryEx(
        data, size, MemoryMockAlloc, MemoryMockFree, MemoryDefaultLoadLibrary,
        MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, &expected_calls);

    free_calls_after_loading = expected_calls.current_free_call;

    MemoryFreeLibrary(handle);
    assert(expected_calls.current_free_call == free_calls_after_loading);
}

void TestFreeAfterDefaultAlloc(void *data, size_t size) {
    CallList expected_calls;
    HMEMORYMODULE handle;
    int free_calls_after_loading;

    // Note: free might get called internally multiple times
    InitFreeFunc(&expected_calls, MemoryDefaultFree);

    handle = MemoryLoadLibraryEx(
        data, size, MemoryDefaultAlloc, MemoryMockFree, MemoryDefaultLoadLibrary,
        MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, &expected_calls);

    assert(handle != NULL);
    free_calls_after_loading = expected_calls.current_free_call;

    MemoryFreeLibrary(handle);
    assert(expected_calls.current_free_call == free_calls_after_loading + 1);
}

#ifdef _WIN64

/*
 * Allocator that shifts the address of its very first call up by 2^40
 * bytes (1 TiB). This makes the image land above the 32-bit address range.
 * Later calls pass their address through unchanged. userdata points to an
 * int call counter that is incremented on every call.
 */
LPVOID MemoryAllocHigh(LPVOID address, SIZE_T size, DWORD allocationType, DWORD protect, void* userdata)
{
    int* counter = static_cast<int*>(userdata);
    LPVOID target = address;

    if ((*counter)++ == 0) {
        target = (LPVOID)((uintptr_t)address + ((uintptr_t)1 << 40));
    }
    return MemoryDefaultAlloc(target, size, allocationType, protect, NULL);
}

/*
 * Load the DLL with MemoryAllocHigh so that the image sits above 4 GiB.
 * Then check that exports, the version resource and string resources
 * still work. Load or lookup failures are reported and asserted on. With
 * asserts disabled, the test stops cleanly instead of calling through NULL.
 */
void TestAllocHighMemory(void *data, size_t size) {
    HMEMORYMODULE handle;
    int counter = 0;
    fnAddsumProc fnAddsum;
    HMEMORYRSRC resourceInfo;
    DWORD resourceSize = 0;
    LPVOID resourceData = NULL;
    TCHAR buffer[100];
    const int bufferCount = (int)(sizeof(buffer) / sizeof(buffer[0]));

    handle = MemoryLoadLibraryEx(
        data, size, MemoryAllocHigh, MemoryDefaultFree, MemoryDefaultLoadLibrary,
        MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, &counter);
    if (handle == NULL)
    {
        OutputLastError(_T("Can't load library from memory"));
        _tprintf(_T("Can't load library from memory.\n"));
    }
    assert(handle != NULL);
    if (handle == NULL)
    {
        return;
    }

    fnAddsum = (fnAddsumProc)MemoryGetProcAddress(handle, CSDLL_LIBRARY_ADDSUM);
    if (fnAddsum == NULL)
    {
        OutputLastError(_T("Can't get proc address"));
        _tprintf(_T("Can't get proc address.\n"));
    }
    assert(fnAddsum != NULL);
    if (fnAddsum == NULL)
    {
        MemoryFreeLibrary(handle);
        return;
    }
    /* CSDLL_LIBRARY_ADDSUM is a narrow string even in UNICODE builds; the MSVC
       CRT reads %hs as char* for both printf and wprintf. */
    _tprintf(_T("get proc address: %hs - 0x%p\n"), CSDLL_LIBRARY_ADDSUM, (void*)fnAddsum);
    _tprintf(_T("From memory: %d\n"), fnAddsum(1, 2));

    resourceInfo = MemoryFindResource(handle, MAKEINTRESOURCE(VS_VERSION_INFO), RT_VERSION);
    _tprintf(_T("MemoryFindResource returned 0x%p\n"), (void*)resourceInfo);

    if (resourceInfo != NULL)
    {
        resourceSize = MemorySizeofResource(handle, resourceInfo);
        resourceData = MemoryLoadResource(handle, resourceInfo);
    }
    _tprintf(_T("Memory resource data: %lu bytes at 0x%p\n"), resourceSize, resourceData);

    /* Pass a character count (not sizeof), reserve the last slot, and
       terminate explicitly in case the string fills the buffer.
       NOTE(review): assumes MemoryLoadString's size argument counts TCHARs,
       as in upstream MemoryModule -- confirm against memload.h. */
    buffer[0] = _T('\0');
    MemoryLoadString(handle, 1, buffer, bufferCount - 1);
    buffer[bufferCount - 1] = _T('\0');
    _tprintf(_T("String1: %s\n"), buffer);

    buffer[0] = _T('\0');
    MemoryLoadString(handle, 20, buffer, bufferCount - 1);
    buffer[bufferCount - 1] = _T('\0');
    _tprintf(_T("String2: %s\n"), buffer);

    MemoryFreeLibrary(handle);
}
#endif  // _WIN64

/*
 * Run the custom allocator/free test cases against an in-memory copy of the
 * test DLL. The high-memory case runs only in 64-bit builds.
 *
 * Returns 1 if the DLL file cannot be read and 0 otherwise. Individual
 * cases report failures through assert().
 */
int TestCustomAllocAndFree(void)
{
    size_t size;
    void *image;

    PrintCWD();
    _tprintf(_T("%s\n"), _T("TestCustomAllocAndFree ......"));

    image = ReadLibrary(&size);
    if (image == NULL)
    {
        return 1;
    }

    _tprintf(_T("Test MemoryLoadLibraryEx after initially failing allocation function\n"));
    TestFailingAllocation(image, size);

    _tprintf(_T("Test cleanup after MemoryLoadLibraryEx with failing allocation function\n"));
    TestCleanupAfterFailingAllocation(image, size);

    _tprintf(_T("Test custom free function after MemoryLoadLibraryEx\n"));
    TestFreeAfterDefaultAlloc(image, size);

#ifdef _WIN64
    _tprintf(_T("Test allocating in high memory\n"));
    TestAllocHighMemory(image, size);
#endif

    free(image);
    return 0;
}

#if defined (__cplusplus) || defined (c_plusplus)
} /* End of extern "C" */
#endif
